Skip to content

[Kernel][Comms] feat: add custom all-gather kernels#1524

Draft
AlpinDale wants to merge 1 commit into
mainfrom
custom_all_gather
Draft

[Kernel][Comms] feat: add custom all-gather kernels#1524
AlpinDale wants to merge 1 commit into
mainfrom
custom_all_gather

Conversation

@AlpinDale
Copy link
Copy Markdown
Collaborator

@AlpinDale AlpinDale commented Sep 19, 2025

We don't really use all-gather all that much, but for context parallel, all-gather is used quite a lot. This adds a fair bit of overhead when doing Context Parallelism, sometimes halving the speed. Currently WIP.

CP will land in #1521

@AlpinDale
Copy link
Copy Markdown
Collaborator Author

/gemini review

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces custom all-gather kernels to optimize performance, particularly for context parallelism. The implementation is a good start, but I've identified several critical issues that need to be addressed. These include function signature mismatches that will cause build failures, a critical bug in handling variable-sized inputs due to incorrect use of std::set, and potential integer overflows from std::accumulate. Please review the detailed comments for fixes.

}

private:
std::set<int> mGroup;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Using std::set for mGroup introduces a critical bug and is inefficient. std::set sorts its elements, which will break the correspondence between the ranks and the sizes vector if the input group_ranks is not sorted. This leads to incorrect data gathering when sizes are variable. Additionally, iterating over the set using std::advance inside run_list is inefficient.

Please change mGroup to be a std::vector<int> to preserve the order of ranks and allow for efficient indexing. This requires changes in multiple places:

  1. Change mGroup's type to std::vector<int> here.
  2. Update AllgatherOp constructor (lines 72-74) to accept std::vector<int>.
  3. Update init_custom_ag (lines 148-154) to construct a std::vector<int> from group_ranks and pass it to AllgatherOp.
  4. Update the loop in run_list (lines 118-121) to use mGroup[root_idx] instead of std::advance.
    std::vector<int> mGroup;

Comment thread kernels/ops.h
int64_t open_mem_handle(torch::Tensor& mem_handle);
void free_shared_buffer(int64_t buffer);

fptr_t init_custom_ag(const std::vector<int64_t>& group_ranks);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The signature of init_custom_ag is inconsistent with its implementation in custom_all_gather.cu and its declaration in custom_all_gather.cuh. The implementation expects a second argument nccl_comm_ptr of type int64_t, which is missing here. This will cause a linker error.

Suggested change
fptr_t init_custom_ag(const std::vector<int64_t>& group_ranks);
fptr_t init_custom_ag(const std::vector<int64_t>& group_ranks, int64_t nccl_comm_ptr);

TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ag), custom_ag) {
// Custom all-gather kernels
custom_ag.def(
"init_custom_ag(int[] group_ranks) -> int");
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The TorchScript signature for init_custom_ag is inconsistent with its C++ implementation. It's missing the nccl_comm_ptr argument. This will cause a compilation or runtime error. The signature should accept an additional integer for the communicator pointer.

Suggested change
"init_custom_ag(int[] group_ranks) -> int");
"init_custom_ag(int[] group_ranks, int nccl_comm_ptr) -> int");

std::all_of(sizes.value().begin(), sizes.value().end(),
[&sizes](int64_t size) { return size == sizes.value()[0]; });

int64_t sum_sizes = sizes.has_value() ? std::accumulate(sizes.value().begin(), sizes.value().end(), 0, std::plus<>{}) : 0;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The initial value for std::accumulate is 0, which is an int. Since the sizes vector contains int64_t values, the sum could overflow an int if it exceeds INT_MAX. The accumulator's type is determined by the type of this initial value. To prevent overflow, please use an int64_t initial value.

        int64_t sum_sizes = sizes.has_value() ? std::accumulate(sizes.value().begin(), sizes.value().end(), int64_t{0}, std::plus<>{}) : 0;

AT_CUDA_CHECK(ncclAllGather(input.data_ptr(), output.mutable_data_ptr(), input.numel(), (*getDtypeMap())[type],
mNcclComm, stream));
} else {
size_t numel_base = std::accumulate(outputShape.cbegin() + 1, outputShape.cend(), 1, std::multiplies<>{});
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The initial value for std::accumulate is 1, which is an int. The product of tensor dimensions can easily overflow an int. The accumulator's type is determined by this initial value. Please use a size_t initial value to prevent potential overflow, as the result is stored in a size_t.

                size_t numel_base = std::accumulate(outputShape.cbegin() + 1, outputShape.cend(), size_t{1}, std::multiplies<>{});

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant